{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convolution demos"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we use the batch, multi-channel convolution operation implemented in Numpy (that you can find [here](../lincoln/lincoln/conv.py)) to train a small convolutional neural network to more than 90% accuracy on MNIST."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import lincoln\n",
    "from lincoln.layers import Dense\n",
    "from lincoln.losses import SoftmaxCrossEntropy, MeanSquaredError\n",
    "from lincoln.optimizers import Optimizer, SGD, SGDMomentum\n",
    "from lincoln.activations import Sigmoid, Tanh, Linear, ReLU\n",
    "from lincoln.network import NeuralNetwork\n",
    "from lincoln.train import Trainer\n",
    "from lincoln.utils import mnist\n",
    "from lincoln.layers import Conv2D\n",
    "\n",
    "X_train, y_train, X_test, y_test = mnist.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test = X_train - np.mean(X_train), X_test - np.mean(X_train)\n",
    "X_train, X_test = X_train / np.std(X_train), X_test / np.std(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_conv, X_test_conv = X_train.reshape(-1, 1, 28, 28), X_test.reshape(-1, 1, 28, 28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_labels = len(y_train)\n",
    "train_labels = np.zeros((num_labels, 10))\n",
    "for i in range(num_labels):\n",
    "    train_labels[i][y_train[i]] = 1\n",
    "\n",
    "num_labels = len(y_test)\n",
    "test_labels = np.zeros((num_labels, 10))\n",
    "for i in range(num_labels):\n",
    "    test_labels[i][y_test[i]] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_accuracy_model(model, test_set):\n",
    "    return print(f'''The model validation accuracy is: \n",
    "    {np.equal(np.argmax(model.forward(test_set, inference=True), axis=1), y_test).sum() * 100.0 / test_set.shape[0]:.2f}%''')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CNN from scratch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation accuracy after 100 batches is 86.85%\n",
      "Validation accuracy after 200 batches is 83.78%\n",
      "Validation accuracy after 300 batches is 90.42%\n",
      "Validation accuracy after 400 batches is 89.08%\n",
      "Validation accuracy after 500 batches is 90.01%\n",
      "Validation accuracy after 600 batches is 90.57%\n",
      "Validation accuracy after 700 batches is 84.27%\n",
      "Validation accuracy after 800 batches is 91.85%\n",
      "Validation accuracy after 900 batches is 92.50%\n",
      "Validation loss after 1 epochs is 3.615\n"
     ]
    }
   ],
   "source": [
    "model = NeuralNetwork(\n",
    "    layers=[Conv2D(out_channels=16,\n",
    "                   param_size=5,\n",
    "                   dropout=0.8,\n",
    "                   weight_init=\"glorot\",\n",
    "                   flatten=True,\n",
    "                  activation=Tanh()),\n",
    "            Dense(neurons=10, \n",
    "                  activation=Linear())],\n",
    "            loss = SoftmaxCrossEntropy(), \n",
    "seed=20190402)\n",
    "\n",
    "trainer = Trainer(model, SGDMomentum(lr = 0.1, momentum=0.9))\n",
    "trainer.fit(X_train_conv, train_labels, X_test_conv, test_labels,\n",
    "            epochs = 1,\n",
    "            eval_every = 1,\n",
    "            seed=20190402,\n",
    "            batch_size=60,\n",
    "            conv_testing=True);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model validation accuracy is: \n",
      "    90.31%\n"
     ]
    }
   ],
   "source": [
    "calc_accuracy_model(model, X_test_conv)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
